{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "(sphx_glr_tutorial_auto_scheduler_matmul_x86)=\n",
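        "# Optimizing Operators with Auto-scheduling\n",
        "\n",
        "**Authors**: [Lianmin Zheng](https://github.com/merrymercy), [Chengfan Jia](https://github.com/jcf94/)\n",
        "\n",
        "In this tutorial, we show how TVM's auto-scheduling feature can find good schedules without the need to write a custom template.\n",
        "\n",
        "The template-based [AutoTVM](autotvm_matmul_x86) relies on manually written templates to define the search space. The auto-scheduler does not require any templates.\n",
        "\n",
        "Users only need to write the computation declaration, without any schedule primitives or templates. The auto-scheduler automatically generates a large search space and finds a good schedule within it.\n",
        "\n",
        "We use matrix multiplication as the example in this tutorial."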
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tvm\n",
        "from tvm import te, auto_scheduler"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
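        "## Define the Matrix Multiplication\n",
        "\n",
        "To start, we define a matrix multiplication with a bias addition. It uses standard operations from TVM's Tensor Expression language. The major difference is the {func}`tvm.auto_scheduler.register_workload` decorator at the top of the function definition. The function should return a list of input/output tensors. From these tensors, the auto-scheduler can recover the whole computational graph."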
      ]
    },
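    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As we understand it, registering the function lets the auto-scheduler refer to the workload by a key built from the function name and its arguments, and rebuild the tensors from that key later. The snippet below is only a toy sketch of that general decorator-registry pattern, not TVM's implementation. `_REGISTRY`, `register`, `make_key`, `rebuild` and `toy_workload` are hypothetical names used for illustration.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "# Hypothetical registry mapping a function name to the function that builds the workload.\n",
        "_REGISTRY = {}\n",
        "\n",
        "\n",
        "def register(func):\n",
        "    # Record the function under its own name and return it unchanged.\n",
        "    _REGISTRY[func.__name__] = func\n",
        "    return func\n",
        "\n",
        "\n",
        "def make_key(func, args):\n",
        "    # Hypothetical key format: the function name followed by its arguments, as JSON.\n",
        "    return json.dumps([func.__name__] + list(args))\n",
        "\n",
        "\n",
        "def rebuild(key):\n",
        "    # Look the function up by name and call it again with the stored arguments.\n",
        "    name, *args = json.loads(key)\n",
        "    return _REGISTRY[name](*args)\n",
        "\n",
        "\n",
        "@register\n",
        "def toy_workload(n, dtype):\n",
        "    # Stand-in for a compute declaration: it only describes the tensors.\n",
        "    return [('A', (n, n), dtype), ('out', (n, n), dtype)]\n",
        "\n",
        "\n",
        "key = make_key(toy_workload, (4, 'float32'))\n",
        "print(rebuild(key))\n",
        "```"
      ]
    },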
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@auto_scheduler.register_workload  # Note the auto_scheduler decorator\n",
        "def matmul_add(N, L, M, dtype):\n",
        "    A = te.placeholder((N, L), name=\"A\", dtype=dtype)\n",
        "    B = te.placeholder((L, M), name=\"B\", dtype=dtype)\n",
        "    C = te.placeholder((N, M), name=\"C\", dtype=dtype)\n",
        "\n",
        "    k = te.reduce_axis((0, L), name=\"k\")\n",
        "    matmul = te.compute(\n",
        "        (N, M),\n",
        "        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),\n",
        "        name=\"matmul\",\n",
        "        attrs={\"layout_free_placeholders\": [B]},  # enable automatic layout transformation for tensor B\n",
        "    )\n",
        "    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name=\"out\")\n",
        "    return [A, B, C, out]"
      ]
    },
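    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`layout_free_placeholders` tells the auto-scheduler that it may store `B` in a different memory layout. As an illustration only, the sketch below repacks a row-major `L` by `M` matrix into a blocked layout `[M/jb][L/kb][kb][jb]`, so that `jb` consecutive columns of `kb` consecutive rows become contiguous. The helper names and the block sizes `kb=2`, `jb=4` are assumptions. These sizes happen to match the `auto_scheduler_layout_transform` indexing in the lowered IR printed at the end of this tutorial for this particular run, and another search may choose a different layout.\n",
        "\n",
        "```python\n",
        "def pack_b(B, kb, jb):\n",
        "    # Repack a row-major L x M matrix (list of rows) into a flat list laid out as\n",
        "    # [M // jb][L // kb][kb][jb].\n",
        "    L, M = len(B), len(B[0])\n",
        "    assert L % kb == 0 and M % jb == 0\n",
        "    packed = [0.0] * (L * M)\n",
        "    for jo in range(M // jb):\n",
        "        for ko in range(L // kb):\n",
        "            for ki in range(kb):\n",
        "                for ji in range(jb):\n",
        "                    dst = ((jo * (L // kb) + ko) * kb + ki) * jb + ji\n",
        "                    packed[dst] = B[ko * kb + ki][jo * jb + ji]\n",
        "    return packed\n",
        "\n",
        "\n",
        "def packed_get(packed, L, kb, jb, k, j):\n",
        "    # Read the logical element B[k][j] back from the packed buffer.\n",
        "    ko, ki = divmod(k, kb)\n",
        "    jo, ji = divmod(j, jb)\n",
        "    return packed[((jo * (L // kb) + ko) * kb + ki) * jb + ji]\n",
        "\n",
        "\n",
        "# Small example: an 8 x 8 matrix with B[k][j] = 10 * k + j.\n",
        "L = M = 8\n",
        "B = [[10.0 * k + j for j in range(M)] for k in range(L)]\n",
        "packed = pack_b(B, kb=2, jb=4)\n",
        "print(packed[:8])  # [0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0]\n",
        "```\n",
        "\n",
        "With `L = 1024` and `jb = 4`, the destination offset works out to `jo * 4096 + ko * 8 + ki * 4 + ji`, which is the same form as the store index in the lowered IR."
      ]
    },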
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
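        "## Create the Search Task\n",
        "\n",
        "With the function defined, we can now create the task for `auto_scheduler` to search against. We specify the parameters of this matrix multiplication. In this case it multiplies two square matrices of size $1024 \\times 1024$, so we create the search task with `N=L=M=1024` and `dtype=\"float32\"`.\n",
        "\n",
        "```{admonition} Improve performance with a custom target\n",
        "For TVM to take full advantage of a specific hardware platform, specify your CPU's capabilities manually. For example:\n",
        "\n",
        "- replace ``llvm`` below with ``llvm -mcpu=core-avx2`` to enable AVX2\n",
        "- replace ``llvm`` below with ``llvm -mcpu=skylake-avx512`` to enable AVX-512\n",
        "```"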
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Computational DAG:\n",
            "A = PLACEHOLDER [1024, 1024]\n",
            "B = PLACEHOLDER [1024, 1024]\n",
            "matmul(i, j) += (A[i, k]*B[k, j])\n",
            "C = PLACEHOLDER [1024, 1024]\n",
            "out(i, j) = (matmul[i, j] + C[i, j])\n",
            "\n"
          ]
        }
      ],
      "source": [
        "target = tvm.target.Target(\"llvm\")\n",
        "N = L = M = 1024\n",
        "task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, \"float32\"), target=target)\n",
        "\n",
        "# Inspect the computational graph\n",
        "print(\"Computational DAG:\")\n",
        "print(task.compute_dag)"
      ]
    },
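    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The printed DAG has two stages. The first is the reduction `matmul(i, j) += A[i, k]*B[k, j]`, and the second is the element-wise `out(i, j) = matmul[i, j] + C[i, j]`. The plain-Python reference below writes out the same two stages. It is an illustration only and far too slow for 1024 by 1024 inputs, and `matmul_add_ref` is a hypothetical helper. A reference like this can be handy for checking a compiled kernel on small inputs.\n",
        "\n",
        "```python\n",
        "def matmul_add_ref(A, B, C):\n",
        "    # Plain-Python reference for the DAG above, on lists of rows.\n",
        "    N, L, M = len(A), len(B), len(B[0])\n",
        "    # Stage 1: matmul[i][j] = sum over k of A[i][k] * B[k][j]\n",
        "    matmul = [[0.0] * M for _ in range(N)]\n",
        "    for i in range(N):\n",
        "        for j in range(M):\n",
        "            for k in range(L):\n",
        "                matmul[i][j] += A[i][k] * B[k][j]\n",
        "    # Stage 2: out[i][j] = matmul[i][j] + C[i][j]\n",
        "    return [[matmul[i][j] + C[i][j] for j in range(M)] for i in range(N)]\n",
        "\n",
        "\n",
        "A = [[1.0, 2.0], [3.0, 4.0]]\n",
        "B = [[5.0, 6.0], [7.0, 8.0]]\n",
        "C = [[1.0, 1.0], [1.0, 1.0]]\n",
        "print(matmul_add_ref(A, B, C))  # [[20.0, 23.0], [44.0, 51.0]]\n",
        "```"
      ]
    },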
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
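        "## Set Parameters for the Auto-scheduler\n",
        "\n",
        "Next, we set the parameters for the auto-scheduler.\n",
        "\n",
        "* `num_measure_trials` is the number of measurement trials that can be used during the search. For a fast demonstration, this tutorial makes only 10 trials. In practice, 1000 is a good value for the search to converge, and you can run more trials if your time budget allows.\n",
        "* In addition, we use {any}`RecordToFile <auto_scheduler.RecordToFile>` to log the measurement records into the file `matmul.json`. These records can be used to query the historical best schedule, resume the search, and do further analysis later.\n",
        "* See {any}`TuningOptions <auto_scheduler.TuningOptions>` for more information about the parameters."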
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "log_file = \"matmul.json\"\n",
        "tune_option = auto_scheduler.TuningOptions(\n",
        "    num_measure_trials=10,\n",
        "    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],\n",
        "    verbose=2,\n",
        ")"
      ]
    },
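    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because every measurement is logged, the best schedule can be recovered later. The snippet below does not use TVM's record format or API. Its in-memory records are hypothetical: each holds a label, the measured costs in seconds, and an error code, and the costs are copied from trials No. 1, 3 and 5 of the search log shown below. It only illustrates one plausible selection rule: discard failed measurements and keep the record with the lowest mean cost. In this tutorial, loading the best schedule from `matmul.json` is done by `task.apply_best(log_file)`.\n",
        "\n",
        "```python\n",
        "# Hypothetical records: (schedule label, measured costs in seconds, error code).\n",
        "records = [\n",
        "    ('trial-1', [0.0171], 0),\n",
        "    ('trial-3', [0.0084], 0),\n",
        "    ('trial-5', [0.0086], 0),\n",
        "    ('failed', [1e10], 1),  # a failed measurement, e.g. a build or runtime error\n",
        "]\n",
        "\n",
        "\n",
        "def best_record(records):\n",
        "    # Keep only successful measurements (error code 0), then pick the lowest mean cost.\n",
        "    ok = [r for r in records if r[2] == 0]\n",
        "    return min(ok, key=lambda r: sum(r[1]) / len(r[1]))\n",
        "\n",
        "\n",
        "print(best_record(records)[0])  # trial-3\n",
        "```"
      ]
    },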
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
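        "## Run the Search\n",
        "\n",
        "Now all the inputs are ready. We can kick off the search and let the auto-scheduler do its work. After the measurement trials finish, we load the best schedule from the log file and apply it. `task.apply_best` returns the schedule together with the list of argument tensors."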
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "----------------------------------------------------------------------\n",
            "------------------------------  [ Search ]\n",
            "----------------------------------------------------------------------\n",
            "Generate Sketches\t\t#s: 3\n",
            "Sample Initial Population\t#s: 2012\tfail_ct: 4\tTime elapsed: 2.99\n",
            "GA Iter: 0\tMax score: 0.9999\tMin score: 0.9356\t#Pop: 128\t#M+: 0\t#M-: 0\n",
            "GA Iter: 4\tMax score: 1.0000\tMin score: 0.9877\t#Pop: 128\t#M+: 1384\t#M-: 79\n",
            "EvolutionarySearch\t\t#s: 128\tTime elapsed: 12.76\n",
            "----------------------------------------------------------------------\n",
            "------------------------------  [ Measure ]\n",
            "----------------------------------------------------------------------\n",
            "Get 10 programs to measure:\n",
            "..........**********\n",
            "==================================================\n",
            "No: 1\tGFLOPS: 125.83 / 125.83\tresults: MeasureResult(cost:[0.0171], error_no:0, all_cost:0.53, Tstamp:1679472852.34)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "parallel i.0@j.0@i.1@ (0,131072)\n",
            "  matmul auto_unroll: 64\n",
            "  for k.0 (0,128)\n",
            "    for i.2 (0,4)\n",
            "      for k.1 (0,8)\n",
            "        for i.3 (0,2)\n",
            "          matmul = ...\n",
            "  for i.2 (0,8)\n",
            "    out = ...\n",
            "\n",
            "==================================================\n",
            "No: 2\tGFLOPS: 6.94 / 125.83\tresults: MeasureResult(cost:[0.3098], error_no:0, all_cost:2.33, Tstamp:1679472853.71)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "parallel i.0@j.0@ (0,2)\n",
            "  matmul auto_unroll: 64\n",
            "  for i.1 (0,4)\n",
            "    for j.1 (0,256)\n",
            "      for k.0 (0,128)\n",
            "        for i.2 (0,4)\n",
            "          for k.1 (0,8)\n",
            "            for i.3 (0,32)\n",
            "              vectorize j.3 (0,4)\n",
            "                matmul = ...\n",
            "  for i.1 (0,512)\n",
            "    for j.1 (0,1024)\n",
            "      out = ...\n",
            "\n",
            "==================================================\n",
            "No: 3\tGFLOPS: 256.41 / 256.41\tresults: MeasureResult(cost:[0.0084], error_no:0, all_cost:0.69, Tstamp:1679472854.15)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "parallel i.0@j.0@ (0,8192)\n",
            "  for i.1 (0,2)\n",
            "    matmul auto_unroll: 64\n",
            "    for k.0 (0,512)\n",
            "      for k.1 (0,2)\n",
            "        for i.3 (0,4)\n",
            "          vectorize j.3 (0,16)\n",
            "            matmul = ...\n",
            "    for i.2 (0,4)\n",
            "      vectorize j.2 (0,16)\n",
            "        out = ...\n",
            "\n",
            "==================================================\n",
            "No: 4\tGFLOPS: 82.04 / 256.41\tresults: MeasureResult(cost:[0.0262], error_no:0, all_cost:0.70, Tstamp:1679472854.52)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "parallel i.0@j.0@ (0,16)\n",
            "  matmul auto_unroll: 64\n",
            "  for i.1 (0,4)\n",
            "    for j.1 (0,8)\n",
            "      for k.0 (0,32)\n",
            "        for j.2 (0,16)\n",
            "          for k.1 (0,32)\n",
            "            for i.3 (0,32)\n",
            "              vectorize j.3 (0,4)\n",
            "                matmul = ...\n",
            "  for i.1 (0,128)\n",
            "    for j.1 (0,512)\n",
            "      out = ...\n",
            "\n",
            "==================================================\n",
            "No: 5\tGFLOPS: 250.11 / 256.41\tresults: MeasureResult(cost:[0.0086], error_no:0, all_cost:0.83, Tstamp:1679472855.12)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "parallel i.0@j.0@ (0,1024)\n",
            "  for j.1 (0,4)\n",
            "    for k.0 (0,256)\n",
            "      for i.2 (0,8)\n",
            "        for j.2 (0,4)\n",
            "          for k.1 (0,4)\n",
            "            for i.3 (0,2)\n",
            "              vectorize j.3 (0,4)\n",
            "                matmul = ...\n",
            "    for i.2 (0,16)\n",
            "      vectorize j.2 (0,16)\n",
            "        out = ...\n",
            "\n",
            "==================================================\n",
            "No: 6\tGFLOPS: 160.68 / 256.41\tresults: MeasureResult(cost:[0.0134], error_no:0, all_cost:0.61, Tstamp:1679472855.42)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "matmul auto_unroll: 16\n",
            "parallel i.0@j.0@i.1@j.1@ (0,8192)\n",
            "  for k.0 (0,256)\n",
            "    for i.2 (0,2)\n",
            "      for k.1 (0,4)\n",
            "        for i.3 (0,2)\n",
            "          for j.3 (0,32)\n",
            "            matmul = ...\n",
            "parallel i (0,1024)\n",
            "  for j (0,1024)\n",
            "    out = ...\n",
            "\n",
            "==================================================\n",
            "No: 7\tGFLOPS: 48.42 / 256.41\tresults: MeasureResult(cost:[0.0444], error_no:0, all_cost:0.81, Tstamp:1679472855.74)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "parallel i.0@j.0@ (0,128)\n",
            "  matmul auto_unroll: 512\n",
            "  for i.1 (0,8)\n",
            "    for k.0 (0,32)\n",
            "      for i.2 (0,32)\n",
            "        for j.2 (0,32)\n",
            "          for k.1 (0,32)\n",
            "            matmul = ...\n",
            "  for i.1 (0,256)\n",
            "    for j.1 (0,32)\n",
            "      out = ...\n",
            "\n",
            "==================================================\n",
            "No: 8\tGFLOPS: 20.51 / 256.41\tresults: MeasureResult(cost:[0.1047], error_no:0, all_cost:0.69, Tstamp:1679472856.31)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "parallel i.0@j.0@i.1@j.1@ (0,128)\n",
            "  for k.0 (0,64)\n",
            "    for j.2 (0,1024)\n",
            "      for k.1 (0,16)\n",
            "        for i.3 (0,8)\n",
            "          matmul = ...\n",
            "parallel i (0,1024)\n",
            "  for j (0,1024)\n",
            "    out = ...\n",
            "\n",
            "==================================================\n",
            "No: 9\tGFLOPS: 90.08 / 256.41\tresults: MeasureResult(cost:[0.0239], error_no:0, all_cost:0.52, Tstamp:1679472856.66)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "matmul auto_unroll: 16\n",
            "parallel i.0@j.0@i.1@ (0,4096)\n",
            "  for j.1 (0,8)\n",
            "    for k.0 (0,32)\n",
            "      for i.2 (0,4)\n",
            "        for k.1 (0,32)\n",
            "          for i.3 (0,8)\n",
            "            matmul = ...\n",
            "parallel i (0,1024)\n",
            "  for j (0,1024)\n",
            "    out = ...\n",
            "\n",
            "==================================================\n",
            "No: 10\tGFLOPS: 4.71 / 256.41\tresults: MeasureResult(cost:[0.4557], error_no:0, all_cost:2.05, Tstamp:1679472858.62)\n",
            "==================================================\n",
            "Placeholder: A, B, C\n",
            "matmul auto_unroll: 64\n",
            "parallel i.0@j.0@i.1@j.1@ (0,2048)\n",
            "  for k.0 (0,1024)\n",
            "    for i.2 (0,512)\n",
            "      matmul = ...\n",
            "parallel i (0,1024)\n",
            "  for j (0,1024)\n",
            "    out = ...\n",
            "\n",
            "Time elapsed for measurement: 11.32 s\n",
            "----------------------------------------------------------------------\n",
            "------------------------------  [ Done ]\n",
            "----------------------------------------------------------------------\n"
          ]
        }
      ],
      "source": [
        "# Run auto-tuning (search)\n",
        "task.tune(tune_option)\n",
        "# Apply the best schedule\n",
        "sch, args = task.apply_best(log_file)"
      ]
    },
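    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `GFLOPS` column in the log above is throughput: floating-point operations divided by the measured time. Assume the operation count is 2·N·L·M for the matrix multiplication (one multiply and one add per reduction step) plus N·M for the bias addition. Under that assumption, the sketch below roughly reproduces the reported numbers. TVM's exact counting rule is an assumption here, and small differences also come from the costs being rounded to four digits in the log.\n",
        "\n",
        "```python\n",
        "N = L = M = 1024\n",
        "\n",
        "# Assumed FLOP count: a multiply and an add per (i, j, k) step, plus one add per output for the bias.\n",
        "flops = 2 * N * L * M + N * M\n",
        "\n",
        "\n",
        "def gflops(flop_count, seconds):\n",
        "    # Throughput in billions of floating-point operations per second.\n",
        "    return flop_count / seconds / 1e9\n",
        "\n",
        "\n",
        "print(flops)                            # 2148532224\n",
        "print(round(gflops(flops, 0.0171), 1))  # about 125.6 (log: 125.83)\n",
        "print(round(gflops(flops, 0.0084), 1))  # about 255.8 (log: 256.41)\n",
        "```"
      ]
    },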
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inspect the Optimized Schedule\n",
        "\n",
        "We can lower the schedule to see the IR after auto-scheduling. The auto-scheduler correctly applies optimizations including multi-level tiling, layout transformation, parallelization, vectorization, unrolling, and operator fusion."
      ]
    },
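    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Tiling splits each loop into nested pieces. This lets a small block of the output stay in registers or cache while the reduction runs. The pure-Python sketch below shows a single level of tiling over `i`, `j` and `k`. The tile sizes are arbitrary and chosen only for illustration. The auto-scheduler picks its own multi-level split factors, such as the `k_outer` loop of extent 512 in the IR below. The sketch checks that the tiled loop nest gives the same result as the naive one, because tiling only reorders the iterations.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "\n",
        "def matmul_naive(A, B, N, L, M):\n",
        "    C = [[0.0] * M for _ in range(N)]\n",
        "    for i in range(N):\n",
        "        for j in range(M):\n",
        "            for k in range(L):\n",
        "                C[i][j] += A[i][k] * B[k][j]\n",
        "    return C\n",
        "\n",
        "\n",
        "def matmul_tiled(A, B, N, L, M, ti, tj, tk):\n",
        "    # Outer loops walk over tiles. Inner loops stay inside one ti x tj output tile\n",
        "    # and one tk-wide slice of the reduction. Assumes the sizes divide evenly.\n",
        "    assert N % ti == 0 and M % tj == 0 and L % tk == 0\n",
        "    C = [[0.0] * M for _ in range(N)]\n",
        "    for io in range(0, N, ti):\n",
        "        for jo in range(0, M, tj):\n",
        "            for ko in range(0, L, tk):\n",
        "                for i in range(io, io + ti):\n",
        "                    for k in range(ko, ko + tk):\n",
        "                        a = A[i][k]\n",
        "                        for j in range(jo, jo + tj):\n",
        "                            C[i][j] += a * B[k][j]\n",
        "    return C\n",
        "\n",
        "\n",
        "random.seed(0)\n",
        "N, L, M = 8, 16, 12\n",
        "A = [[random.randint(-3, 3) for _ in range(L)] for _ in range(N)]\n",
        "B = [[random.randint(-3, 3) for _ in range(M)] for _ in range(L)]\n",
        "same = matmul_naive(A, B, N, L, M) == matmul_tiled(A, B, N, L, M, ti=4, tj=4, tk=8)\n",
        "print(same)  # True\n",
        "```"
      ]
    },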
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(\n",
              "        A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        C: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "        out: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1024</span>, <span style=\"color: #008000\">1024</span>), <span style=\"color: #BA2121\">&quot;float32&quot;</span>),\n",
              "    ):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr(\n",
              "            {\n",
              "                <span style=\"color: #BA2121\">&quot;from_legacy_te_schedule&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>,\n",
              "                <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>,\n",
              "            }\n",
              "        )\n",
              "        auto_scheduler_layout_transform <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">1048576</span>], <span style=\"color: #BA2121\">&quot;float32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "        auto_scheduler_layout_transform_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer(\n",
              "            (<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>auto_scheduler_layout_transform\n",
              "        )\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0_ax1_fused_ax2_fused <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel(<span style=\"color: #008000\">256</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> ax4, ax6, ax7 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">512</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">4</span>):\n",
              "                B_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>B<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                auto_scheduler_layout_transform_1[\n",
              "                    ax0_ax1_fused_ax2_fused <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> ax4 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">8</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> ax6 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> ax7\n",
              "                ] <span style=\"color: #AA22FF; font-weight: bold\">=</span> B_1[ax4 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">2048</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> ax6 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> ax0_ax1_fused_ax2_fused <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> ax7]\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i_outer_outer_j_outer_outer_fused <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>parallel(<span style=\"color: #008000\">2048</span>):\n",
              "            matmul <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>allocate([<span style=\"color: #008000\">32</span>], <span style=\"color: #BA2121\">&quot;float32&quot;</span>, <span style=\"color: #BA2121\">&quot;global&quot;</span>)\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> j_outer_inner <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">16</span>):\n",
              "                matmul_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">32</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>matmul)\n",
              "                matmul_1[<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">4</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                matmul_1[<span style=\"color: #008000\">4</span>:<span style=\"color: #008000\">8</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                matmul_1[<span style=\"color: #008000\">8</span>:<span style=\"color: #008000\">12</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                matmul_1[<span style=\"color: #008000\">12</span>:<span style=\"color: #008000\">16</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                matmul_1[<span style=\"color: #008000\">16</span>:<span style=\"color: #008000\">20</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                matmul_1[<span style=\"color: #008000\">20</span>:<span style=\"color: #008000\">24</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                matmul_1[<span style=\"color: #008000\">24</span>:<span style=\"color: #008000\">28</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                matmul_1[<span style=\"color: #008000\">28</span>:<span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0</span>), <span style=\"color: #008000\">4</span>)\n",
              "                <span style=\"color: #008000; font-weight: bold\">for</span> k_outer <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">512</span>):\n",
              "                    cse_var_3: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        i_outer_outer_j_outer_outer_fused <span style=\"color: #AA22FF; font-weight: bold\">//</span> <span style=\"color: #008000\">16</span> <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">8192</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">2</span>\n",
              "                    )\n",
              "                    cse_var_2: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        i_outer_outer_j_outer_outer_fused <span style=\"color: #AA22FF; font-weight: bold\">%</span> <span style=\"color: #008000\">16</span> <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">65536</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> j_outer_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4096</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> k_outer <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">8</span>\n",
              "                    )\n",
              "                    cse_var_1: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>\n",
              "                    A_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>A<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                    matmul_1[<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">4</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">4</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">4</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">4</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">4</span>:<span style=\"color: #008000\">8</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">4</span>:<span style=\"color: #008000\">8</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1024</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">4</span>:<span style=\"color: #008000\">8</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">4</span>:<span style=\"color: #008000\">8</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">1025</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">8</span>:<span style=\"color: #008000\">12</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">8</span>:<span style=\"color: #008000\">12</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">2048</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">8</span>:<span style=\"color: #008000\">12</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">8</span>:<span style=\"color: #008000\">12</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">2049</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">12</span>:<span style=\"color: #008000\">16</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">12</span>:<span style=\"color: #008000\">16</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">3072</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">12</span>:<span style=\"color: #008000\">16</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">12</span>:<span style=\"color: #008000\">16</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">3073</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">16</span>:<span style=\"color: #008000\">20</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">16</span>:<span style=\"color: #008000\">20</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4096</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">16</span>:<span style=\"color: #008000\">20</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">16</span>:<span style=\"color: #008000\">20</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4097</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">20</span>:<span style=\"color: #008000\">24</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">20</span>:<span style=\"color: #008000\">24</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">5120</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">20</span>:<span style=\"color: #008000\">24</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">20</span>:<span style=\"color: #008000\">24</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">5121</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">24</span>:<span style=\"color: #008000\">28</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">24</span>:<span style=\"color: #008000\">28</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">6144</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">24</span>:<span style=\"color: #008000\">28</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">24</span>:<span style=\"color: #008000\">28</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">6145</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">28</span>:<span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">28</span>:<span style=\"color: #008000\">32</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">7168</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_2 : cse_var_2 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                    matmul_1[<span style=\"color: #008000\">28</span>:<span style=\"color: #008000\">32</span>] <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        matmul_1[<span style=\"color: #008000\">28</span>:<span style=\"color: #008000\">32</span>]\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Broadcast(A_1[cse_var_3 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">7169</span>], <span style=\"color: #008000\">4</span>)\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">*</span> auto_scheduler_layout_transform_1[cse_var_1 : cse_var_1 <span style=\"color: #AA22FF; font-weight: bold\">+</span> <span style=\"color: #008000\">4</span>]\n",
              "                    )\n",
              "                <span style=\"color: #008000; font-weight: bold\">for</span> i_inner, j_inner <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">8</span>, <span style=\"color: #008000\">4</span>):\n",
              "                    cse_var_4: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>int32 <span style=\"color: #AA22FF; font-weight: bold\">=</span> (\n",
              "                        i_outer_outer_j_outer_outer_fused <span style=\"color: #AA22FF; font-weight: bold\">//</span> <span style=\"color: #008000\">16</span> <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">8192</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">1024</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_outer_outer_j_outer_outer_fused <span style=\"color: #AA22FF; font-weight: bold\">%</span> <span style=\"color: #008000\">16</span> <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">64</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> j_outer_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span>\n",
              "                        <span style=\"color: #AA22FF; font-weight: bold\">+</span> j_inner\n",
              "                    )\n",
              "                    out_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>out<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                    C_1 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">1048576</span>,), data<span style=\"color: #AA22FF; font-weight: bold\">=</span>C<span style=\"color: #AA22FF; font-weight: bold\">.</span>data)\n",
              "                    out_1[cse_var_4] <span style=\"color: #AA22FF; font-weight: bold\">=</span> matmul_1[i_inner <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> j_inner] <span style=\"color: #AA22FF; font-weight: bold\">+</span> C_1[cse_var_4]\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "mod = tvm.lower(sch, args, simple_mode=True)\n",
        "mod.show()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Check correctness and evaluate performance\n",
        "\n",
        "Build the binary and check its correctness and performance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Execution time of this operator: 5.362 ms\n"
          ]
        }
      ],
      "source": [
        "func = tvm.build(sch, args, target)\n",
        "a_np = np.random.uniform(size=(N, L)).astype(np.float32)\n",
        "b_np = np.random.uniform(size=(L, M)).astype(np.float32)\n",
        "c_np = np.random.uniform(size=(N, M)).astype(np.float32)\n",
        "out_np = a_np.dot(b_np) + c_np\n",
        "\n",
        "dev = tvm.cpu()\n",
        "a_tvm = tvm.nd.array(a_np, device=dev)\n",
        "b_tvm = tvm.nd.array(b_np, device=dev)\n",
        "c_tvm = tvm.nd.array(c_np, device=dev)\n",
        "out_tvm = tvm.nd.empty(out_np.shape, device=dev)\n",
        "func(a_tvm, b_tvm, c_tvm, out_tvm)\n",
        "\n",
        "# Check results\n",
        "np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)\n",
        "\n",
        "# Evaluate execution time.\n",
        "evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)\n",
        "print(\n",
        "    \"Execution time of this operator: %.3f ms\"\n",
        "    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)\n",
        ")"
      ]
    },
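    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To put the measured time in perspective, it can be converted into throughput. The following is only a back-of-the-envelope sketch and not part of TVM. It assumes the problem size `N = L = M = 1024` used for this task. It counts `2 * N * L * M` floating-point operations for the matrix multiplication (one multiply and one add per term of each dot product) plus `N * M` for the bias addition, and plugs in the 5.362 ms printed above. The helper name `matmul_add_gflops` is hypothetical, and the time on your machine will differ.\n",
        "\n",
        "```python\n",
        "def matmul_add_gflops(n, l, m, time_ms):\n",
        "    # Hypothetical helper: GFLOP/s achieved by out = A @ B + C in time_ms milliseconds.\n",
        "    # Each of the n*m outputs needs l multiplies and l adds for the dot product,\n",
        "    # plus one add for the bias term C.\n",
        "    flops = 2 * n * l * m + n * m\n",
        "    return flops / (time_ms * 1e-3) / 1e9\n",
        "\n",
        "\n",
        "# Problem size of this task and the execution time printed above (in ms).\n",
        "print(\"%.1f GFLOP/s\" % matmul_add_gflops(1024, 1024, 1024, 5.362))\n",
        "```\n",
        "\n",
        "With these numbers the result is roughly 400 GFLOP/s. You can compare this figure against the peak throughput of the CPU you tuned on."
      ]
    },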
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
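        "## Using the record file\n",
        "\n",
        "During the search, all measurement records are logged to the record file `matmul.json`. These records can be used to re-apply the search results, to resume the search, and to perform other analyses.\n",
        "\n",
        "Here is an example in which we load the best schedule from the file and print the equivalent Python schedule API calls. This is useful for debugging and for learning how the auto-scheduler behaves."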
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Equivalent python schedule:\n",
            "matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)\n",
            "out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)\n",
            "matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=1)\n",
            "matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=8)\n",
            "matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=1)\n",
            "matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=4)\n",
            "matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)\n",
            "matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=16)\n",
            "matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=2)\n",
            "s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)\n",
            "out_i_o_i, out_i_i = s[out].split(out_i, factor=8)\n",
            "out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=1)\n",
            "out_j_o_i, out_j_i = s[out].split(out_j, factor=4)\n",
            "out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=16)\n",
            "s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)\n",
            "s[matmul].compute_at(s[out], out_j_o_i)\n",
            "out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)\n",
            "s[out].parallel(out_i_o_o_j_o_o_fused)\n",
            "s[matmul].pragma(matmul_i_o_o_o, \"auto_unroll_max_step\", 512)\n",
            "s[matmul].pragma(matmul_i_o_o_o, \"unroll_explicit\", True)\n",
            "s[matmul].vectorize(matmul_j_i)\n",
            "\n"
          ]
        }
      ],
      "source": [
        "print(\"Equivalent python schedule:\")\n",
        "print(task.print_best(log_file))"
      ]
    },
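    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The following NumPy sketch mimics the loop structure of the printed schedule on small inputs. It is meant to make that schedule more concrete. `out` is tiled into 8 x 64 blocks, one per iteration of the fused outer loop (which TVM parallelizes). Inside each block, `matmul` is computed at `out_j_o_i` into an 8 x 4 local accumulator. The reduction axis is split by a factor of 2, and a 4-wide row update stands in for the vectorized innermost `j` loop. The tile sizes are read off the printed schedule. The function name `tiled_matmul_add` and the small test shapes are illustrative assumptions. This pure-Python version shows only the iteration order and says nothing about performance.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def tiled_matmul_add(a, b, c, ti=8, tj_outer=16, tj_vec=4, tk=2):\n",
        "    # Hypothetical re-creation of the printed loop nest; shapes must divide evenly.\n",
        "    n, l = a.shape\n",
        "    _, m = b.shape\n",
        "    out = np.empty((n, m), dtype=a.dtype)\n",
        "    block_j = tj_outer * tj_vec  # 64 columns per fused (parallel) block\n",
        "    num_j_blocks = m // block_j\n",
        "    # Fused outer loop over (i block, j block); TVM runs it in parallel.\n",
        "    for fused in range((n // ti) * num_j_blocks):\n",
        "        i0 = (fused // num_j_blocks) * ti\n",
        "        j_block = (fused % num_j_blocks) * block_j\n",
        "        # out_j_o_i: 16 column groups of width 4 inside the block.\n",
        "        for jo in range(tj_outer):\n",
        "            j0 = j_block + jo * tj_vec\n",
        "            # `matmul` is computed here (compute_at) into an 8 x 4 local buffer.\n",
        "            acc = np.zeros((ti, tj_vec), dtype=a.dtype)\n",
        "            for k_o in range(l // tk):\n",
        "                for ii in range(ti):\n",
        "                    for k_i in range(tk):\n",
        "                        k = k_o * tk + k_i\n",
        "                        # 4-wide row update stands in for the vectorized j_i loop.\n",
        "                        acc[ii, :] += a[i0 + ii, k] * b[k, j0 : j0 + tj_vec]\n",
        "            # Epilogue: add the bias and write the finished tile to `out`.\n",
        "            out[i0 : i0 + ti, j0 : j0 + tj_vec] = acc + c[i0 : i0 + ti, j0 : j0 + tj_vec]\n",
        "    return out\n",
        "\n",
        "\n",
        "# Small shapes (divisible by the tile sizes) keep the pure-Python loops fast.\n",
        "rng = np.random.default_rng(0)\n",
        "a = rng.random((16, 8), dtype=np.float32)\n",
        "b = rng.random((8, 128), dtype=np.float32)\n",
        "c = rng.random((16, 128), dtype=np.float32)\n",
        "np.testing.assert_allclose(tiled_matmul_add(a, b, c), a.dot(b) + c, rtol=1e-5)\n",
        "```\n",
        "\n",
        "Computing each 8 x 4 tile in a small local buffer mirrors the 32-element `matmul_1` buffer in the lowered IR shown earlier. There, the eight rows of each `k` step appear as unrolled `T.Broadcast` updates."
      ]
    },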
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
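        "A more complicated example is resuming a search. In this case, we need to create the search policy and the cost model ourselves and restore their state from the log file. In the example below, we restore the state and run 5 more trials."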
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Resume search:\n",
            "----------------------------------------------------------------------\n",
            "------------------------------  [ Call init-search callbacks ]\n",
            "----------------------------------------------------------------------\n",
            "SearchPolicy: Loaded 25 measurement records from matmul.json for [\"matmul_add\", 1024, 1024, 1024, \"float32\"]\n",
            "----------------------------------------------------------------------\n",
            "------------------------------  [ Search ]\n",
            "----------------------------------------------------------------------\n",
            "Generate Sketches\t\t#s: 3\n",
            "Sample Initial Population\t#s: 2013\tfail_ct: 6\tTime elapsed: 2.74\n",
            "GA Iter: 0\tMax score: 0.9995\tMin score: 0.9315\t#Pop: 128\t#M+: 0\t#M-: 0\n",
            "GA Iter: 4\tMax score: 0.9998\tMin score: 0.9862\t#Pop: 128\t#M+: 1373\t#M-: 69\n",
            "EvolutionarySearch\t\t#s: 128\tTime elapsed: 12.14\n",
            "----------------------------------------------------------------------\n",
            "------------------------------  [ Measure ]\n",
            "----------------------------------------------------------------------\n",
            "Get 5 programs to measure:\n",
            ".....*****\n",
            "Time elapsed for measurement: 6.14 s\n",
            "----------------------------------------------------------------------\n",
            "------------------------------  [ Done ]\n",
            "----------------------------------------------------------------------\n"
          ]
        }
      ],
      "source": [
        "def resume_search(task, log_file):\n",
        "    print(\"Resume search:\")\n",
        "    cost_model = auto_scheduler.XGBModel()\n",
        "    cost_model.update_from_file(log_file)\n",
        "    search_policy = auto_scheduler.SketchPolicy(\n",
        "        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]\n",
        "    )\n",
        "    tune_option = auto_scheduler.TuningOptions(\n",
        "        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]\n",
        "    )\n",
        "    task.tune(tune_option, search_policy=search_policy)\n",
        "\n",
        "resume_search(task, log_file)"
      ]
    },
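    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The record file is plain text with one measurement record per line, so simple analyses can also be done without TVM. The sketch below counts the records and finds the lowest mean measured cost. It rests on an assumption about the log format that is not stated in this tutorial: each line is a JSON object whose `\"r\"` field is `[costs_in_seconds, error_no, all_cost, timestamp]`, and `error_no == 0` means the measurement succeeded. The helper `summarize_records` is hypothetical. For real use, prefer the `auto_scheduler` APIs shown above, such as `task.print_best`.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "\n",
        "def summarize_records(lines):\n",
        "    # Hypothetical helper. Assumes each line is one JSON record whose \"r\" field\n",
        "    # is [costs_in_seconds, error_no, all_cost, timestamp].\n",
        "    best = None\n",
        "    count = 0\n",
        "    for line in lines:\n",
        "        line = line.strip()\n",
        "        if not line:\n",
        "            continue  # ignore blank lines\n",
        "        record = json.loads(line)\n",
        "        count += 1\n",
        "        costs, error_no = record[\"r\"][0], record[\"r\"][1]\n",
        "        if error_no != 0 or not costs:\n",
        "            continue  # skip failed measurements\n",
        "        mean_cost = sum(costs) / len(costs)\n",
        "        if best is None or mean_cost < best:\n",
        "            best = mean_cost\n",
        "    return count, best\n",
        "\n",
        "\n",
        "# Usage against the tutorial's log file:\n",
        "# with open(\"matmul.json\") as f:\n",
        "#     count, best = summarize_records(f)\n",
        "#     print(count, \"records, best mean cost %.3f ms\" % (best * 1e3))\n",
        "```"
      ]
    },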
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
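        "## Final notes and summary\n",
        "\n",
        "In this tutorial, we have shown how to use the TVM auto-scheduler to automatically optimize a matrix multiplication without specifying a search template. This concludes a series of examples, starting from the Tensor Expression (TE) language, that demonstrate how TVM can optimize computational operations."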
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "ai",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    },
    "vscode": {
      "interpreter": {
        "hash": "e0af55fbb1c4b4e8ca009f3673b968438b459a89daa1170f52b672ab74da765c"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
